import math
import warnings
from torch.optim.lr_scheduler import _LRScheduler


class ExpLR(_LRScheduler):
    """Multiply each parameter group's learning rate by ``exp(gamma)`` every
    ``step_size`` epochs.

    PyTorch has no built-in scheduler of the form ``lr = lr * exp(gamma)``,
    so this one is modelled on ``StepLR`` (it subclasses ``_LRScheduler``
    directly). After ``last_epoch`` epochs the learning rate is::

        lr = base_lr * exp(gamma) ** (last_epoch // step_size)

    A negative ``gamma`` therefore shrinks the learning rate at each update
    and a positive one grows it.

    Args:
        optimizer: Wrapped optimizer whose ``param_groups`` lrs are updated.
        step_size: Number of epochs between successive lr updates.
        gamma: Exponent; each update multiplies the lr by ``exp(gamma)``.
        last_epoch: Index of the last epoch; ``-1`` starts from scratch.
        verbose: Forwarded to the base scheduler only when truthy (see
            ``__init__``).
    """

    def __init__(self, optimizer, step_size, gamma, last_epoch=-1, verbose=False):
        # These must be set before the base __init__, which performs an
        # initial step() that calls get_lr().
        self.step_size = step_size
        self.gamma = gamma
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            # False is the base-class default; omitting it avoids the
            # deprecation warning newer PyTorch emits for any explicit value.
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Chainable form (mirrors ``StepLR.get_lr``).

        Returns the lr for the current epoch, computed from each group's
        current lr: unchanged unless ``last_epoch`` is a non-zero multiple
        of ``step_size``, in which case it is multiplied by ``exp(gamma)``.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        # Keep the lr on the initial step (epoch 0) and between update epochs.
        if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
            return [group['lr'] for group in self.optimizer.param_groups]
        factor = math.exp(self.gamma)
        return [group['lr'] * factor for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Closed form (mirrors ``StepLR._get_closed_form_lr``).

        Returns the lr for ``last_epoch`` computed directly from ``base_lrs``,
        i.e. ``base_lr * exp(gamma) ** (last_epoch // step_size)``, which
        agrees with repeatedly applying ``get_lr``.
        """
        num_updates = self.last_epoch // self.step_size
        factor = math.exp(self.gamma) ** num_updates
        return [base_lr * factor for base_lr in self.base_lrs]

